Conditional Monge Gap: JIT-compatible loss + training estimator#679

Open
DhruvaRajwade wants to merge 6 commits into ott-jax:main from DhruvaRajwade:pr-605-clean
Conversation

@DhruvaRajwade

Conditional Monge Gap: JIT-compatible loss + training estimator

The CMonge paper has been accepted to Nature Machine Intelligence. This PR picks up from #605 and fixes the JIT issue + adds a training estimator.

Changes

cmonge_gap_from_samples -- replaced the jnp.unique loop (breaks jax.jit) with _segment_interface (pad + vmap), following @michalk8's suggestion. Two new required-for-JIT params: num_segments, max_measure_size (consistent with segment_sinkhorn). A logger.warning fires when any condition is padded >10x its actual size, since heavy padding can cause small numerical differences vs non-padded Sinkhorn.
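The pad-and-vmap idea can be illustrated with a small NumPy sketch (a hypothetical simplification, not the actual _segment_interface): each condition's point cloud is padded to a fixed max_measure_size so that a vmapped, JIT-compatible per-segment computation can run on static shapes, with padded rows carrying zero weight.

```python
import numpy as np

def pad_segments(x, labels, num_segments, max_measure_size):
    """Illustrative NumPy sketch of fixed-size segmentation (NOT the
    library's _segment_interface): pad each condition's point cloud to
    max_measure_size so a per-segment computation sees static shapes."""
    dim = x.shape[1]
    padded = np.zeros((num_segments, max_measure_size, dim))
    weights = np.zeros((num_segments, max_measure_size))
    for k in range(num_segments):
        pts = x[labels == k]
        n_k = len(pts)
        padded[k, :n_k] = pts
        weights[k, :n_k] = 1.0 / n_k  # padded rows keep zero weight
    return padded, weights

x = np.arange(10, dtype=float).reshape(5, 2)
labels = np.array([0, 0, 1, 1, 1])
padded, w = pad_segments(x, labels, num_segments=2, max_measure_size=4)
# segment 0 holds 2 real points padded to 4; each weight row sums to 1
```

In the real code the per-segment loop is replaced by jax.vmap over the leading segment axis, which is what makes the whole computation traceable under jax.jit.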

ConditionalMongeGapEstimator -- training wrapper mirroring MongeGapEstimator for conditional maps T(x, c):

loss = fitting(T(x,c), y) + lambda * regularizer(x, T(x,c), labels)
       |__ sinkdiv            |__ cmonge_gap_from_samples

A potential precision tradeoff remains: when all conditions have the same n_k, the segment-based result matches monge_gap_from_samples to ~1e-7. With unequal n_k, smaller conditions are zero-padded to max_measure_size, which changes the Sinkhorn geometry slightly. This is inherent to the segment/vmap approach -- the trade-off buys JIT compatibility.
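The >10x padding warning mentioned above boils down to a simple ratio check. A hedged sketch (hypothetical helper, not the library's code):

```python
def padding_warning_needed(counts, max_measure_size, factor=10):
    """Sketch of the heavy-padding check (hypothetical helper): warn
    when any condition would be padded to more than `factor` times
    its actual number of points."""
    return any(max_measure_size > factor * n for n in counts)

# condition sizes 3, 50, 120 padded to max_measure_size=120:
# the size-3 condition is padded 40x, so a warning should fire
small = padding_warning_needed([3, 50, 120], 120)   # True
ok = padding_warning_needed([50, 120], 120)         # False
```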

Tests (26 passing)

All tests mirror monge_gap_test.py patterns wherever applicable: non-negativity (random + neural map targets), JIT consistency, cost function variants, estimator convergence. Three additional equivalence tests verify cmonge_gap = mean(monge_gap_k) for equal-size conditions, document the padding effect for unequal sizes, and check monotonic gap ordering by difficulty.
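The equivalence and ordering properties being tested can be sketched with a stand-in gap function (mean squared displacement here, NOT the real entropic Monge gap from ott), just to show the aggregation structure:

```python
import numpy as np

def monge_gap_stub(x_k, y_k):
    """Stand-in per-condition gap (mean squared displacement; NOT the
    real entropic Monge gap) used only to illustrate the test logic."""
    return float(np.mean(np.sum((y_k - x_k) ** 2, axis=-1)))

def cmonge_gap_stub(x, y, labels, num_segments):
    # conditional gap = mean of per-condition gaps (equal-size case)
    gaps = [monge_gap_stub(x[labels == k], y[labels == k])
            for k in range(num_segments)]
    return float(np.mean(gaps))

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))
labels = np.repeat(np.arange(2), 4)  # two conditions, 4 points each
y = x.copy()
y[labels == 1] += 5.0  # condition 1 needs a larger displacement
g0 = monge_gap_stub(x[labels == 0], y[labels == 0])  # identity -> 0
g1 = monge_gap_stub(x[labels == 1], y[labels == 1])  # harder -> larger
total = cmonge_gap_stub(x, y, labels, num_segments=2)
```

With equal-size conditions, total equals the mean of the per-condition gaps, and the harder condition (larger shift) yields the larger gap, mirroring the monotonic-ordering test.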

pytest tests/neural/methods/conditional_monge_gap_test.py -v  # 26 tests, ~80s

Usage example

from ott import datasets
from ott.neural.methods.conditional_monge_gap import (
    ConditionalMongeGapEstimator, cmonge_gap_from_samples,
)
from ott.neural.networks.conditional_perturbation_network import (
    ConditionalPerturbationNetwork,
)
from ott.tools.sinkhorn_divergence import sinkdiv
import optax

num_cond, dim_data = 5, 25
train_ds, valid_ds, _, n_cond, max_ms = (
    datasets.create_conditional_gaussian_mixture_samplers(
        num_conditions=num_cond, dim=dim_data,
        train_batch_size=150, valid_batch_size=150,
    )
)

fitting_loss = lambda mapped, target: (sinkdiv(x=mapped, y=target)[0], None)
regularizer = lambda src, mapped, labels: (
    cmonge_gap_from_samples(src, mapped, labels,
        num_segments=n_cond, max_measure_size=max_ms), None)

model = ConditionalPerturbationNetwork(
    dim_hidden=[64, 64], dim_data=dim_data, dim_cond=num_cond,
    dim_cond_map=(32,), is_potential=False,
    context_entity_bonds=((0, num_cond),), num_contexts=1,
)
solver = ConditionalMongeGapEstimator(
    dim_data=dim_data, model=model,
    optimizer=optax.adam(learning_rate=1e-4),
    fitting_loss=fitting_loss, regularizer=regularizer,
    regularizer_strength=5.0, num_train_iters=2000,
    logging=True, valid_freq=50,
)
state, logs = solver.train_map_estimator(*train_ds, *valid_ds)

…IT compatibility

The original cmonge_gap_from_samples used a Python for-loop over
jnp.unique(condition), which breaks JAX JIT compilation since
jnp.unique returns a dynamically-sized array.

Replace with _segment_interface which pads per-condition point clouds
to a fixed max_measure_size and vmaps the per-segment Monge gap
computation. This makes the function fully JIT-compatible.

The eval_fn computes per-segment: displacement_cost - ent_reg_cost,
matching the definition in monge_gap_from_samples. Padded entries
have zero weight and do not affect the result.
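A minimal NumPy sketch of that per-segment quantity (hedged: in the real code the entropic cost comes from a Sinkhorn solve; here it is an input so the zero-weight-padding property can be checked in isolation):

```python
import numpy as np

def per_segment_gap(x, tx, weights, ent_reg_cost):
    """Sketch of the per-segment value described in the commit:
    weighted displacement cost minus an entropic regularized OT cost.
    Padded rows carry zero weight and drop out of the weighted sum."""
    displacement = np.sum(weights * np.sum((tx - x) ** 2, axis=-1))
    return float(displacement - ent_reg_cost)

# two real points plus one zero-weight padded row
x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]])
tx = np.array([[1.0, 0.0], [2.0, 0.0], [9.0, 9.0]])  # padded row is garbage
w = np.array([0.5, 0.5, 0.0])                        # padding: zero weight
gap = per_segment_gap(x, tx, w, ent_reg_cost=0.3)
# displacement = 0.5*1 + 0.5*1 = 1.0, so gap = 0.7 regardless of the
# garbage values in the padded row
```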

New parameters num_segments and max_measure_size are required for JIT
(consistent with segment_sinkhorn API). Cost function parameters
(cost_fn, epsilon, relative_epsilon, scale_cost) are now explicit
rather than passed through **kwargs.

Add the estimator class that mirrors MongeGapEstimator but handles
condition-dependent neural maps T(x, c) with per-condition Monge gap
regularization via cmonge_gap_from_samples.

Changes:
- ConditionalMongeGapEstimator in conditional_monge_gap.py: training loop
  with 3-arg regularizer(source, mapped, labels), 4-iterator batch
  protocol, JIT-compiled step function
- ConditionalDataset + create_conditional_gaussian_mixture_samplers in
  datasets.py: synchronized 4-iterator data pipeline for testing
- Export conditional_perturbation_network from networks/__init__
- 16 tests: 8 unit tests for cmonge_gap_from_samples (non-negativity,
  JIT consistency, loop baseline match, identity vs random, cost fns,
  return shape) + 2 integration tests for the estimator (convergence,
  no-regularizer mode)

… tests

Add 5 new tests to TestConditionalMongeGap:
- test_non_negativity_neural_map: PotentialMLP-based targets
- test_different_costs_give_different_values: PNormP, RegTICost(L1), RegTICost(STVS)
- test_uniform_conditions_equals_averaged_monge_gap: exact equivalence proof
- test_unequal_conditions_shifts_average: structural properties with padding
- test_per_condition_gaps_reflect_difficulty: monotonic gap ordering
@marcocuturi marcocuturi requested a review from michalk8 March 19, 2026 13:49
@codecov

codecov bot commented Mar 19, 2026

Codecov Report

❌ Patch coverage is 88.47926% with 25 lines in your changes missing coverage. Please review.
✅ Project coverage is 87.39%. Comparing base (7ecebc9) to head (9e51b40).
⚠️ Report is 4 commits behind head on main.

Files with missing lines | Patch % | Lines
...eural/networks/conditional_perturbation_network.py | 69.81% | 12 Missing and 4 partials ⚠️
src/ott/neural/methods/conditional_monge_gap.py | 91.50% | 4 Missing and 5 partials ⚠️

@@            Coverage Diff             @@
##             main     #679      +/-   ##
==========================================
+ Coverage   87.35%   87.39%   +0.04%     
==========================================
  Files          82       84       +2     
  Lines        8476     8690     +214     
  Branches      581      600      +19     
==========================================
+ Hits         7404     7595     +191     
- Misses        922      937      +15     
- Partials      150      158       +8     
Files with missing lines | Coverage Δ
src/ott/datasets.py | 97.93% <100.00%> (+2.93%) ⬆️
src/ott/neural/methods/conditional_monge_gap.py | 91.50% <91.50%> (ø)
...eural/networks/conditional_perturbation_network.py | 69.81% <69.81%> (ø)

... and 1 file with indirect coverage changes


- Add logger.warning in cmonge_gap_from_samples when any condition is
  padded >10x its actual size (skipped under JIT via try/except)
- Add runtime timing to test_uniform_conditions_equals_averaged_monge_gap
  comparing segmented vs loop performance
- Rewrite PR_MESSAGE.md to ~half page with concise overview and tutorial plot
@jannisborn

@michalk8 whenever you have time, feel free to let us know any feedback on this PR!
